{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) 2021, salesforce.com, inc.\\\n",
    "All rights reserved.\\\n",
    "SPDX-License-Identifier: BSD-3-Clause\\\n",
    "For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get started quickly with end-to-end multi-agent RL using WarpDrive! This shows a basic example to create a simple multi-agent Tag environment and get training. For more configuration options and indepth explanations, check out the other tutorials and source code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Try this notebook on [Colab](http://colab.research.google.com/github/salesforce/warp-drive/blob/master/tutorials/simple-end-to-end-example.ipynb)!**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ⚠️ PLEASE NOTE:\n",
    "This notebook runs on a GPU runtime.\\\n",
    "If running on Colab, choose Runtime > Change runtime type from the menu, then select 'GPU' in the dropdown."
   ]
  },
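  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the sketch below reports whether PyTorch can see a CUDA device. It relies only on `torch.cuda.is_available()` and `torch.cuda.get_device_name()`; the helper name `describe_gpu` is our own illustration and not part of WarpDrive. If PyTorch is not installed yet, the helper says so instead of failing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib.util\n",
    "\n",
    "\n",
    "def describe_gpu():\n",
    "    \"\"\"Return a short message saying whether a CUDA GPU is visible to PyTorch.\"\"\"\n",
    "    # Hypothetical helper: avoid an ImportError if PyTorch is missing\n",
    "    if importlib.util.find_spec(\"torch\") is None:\n",
    "        return \"PyTorch is not installed yet; install the dependencies first.\"\n",
    "    import torch\n",
    "\n",
    "    if not torch.cuda.is_available():\n",
    "        return \"No CUDA GPU detected; switch to a GPU runtime before continuing.\"\n",
    "    return \"CUDA GPU detected: \" + torch.cuda.get_device_name(0)\n",
    "\n",
    "\n",
    "print(describe_gpu())"
   ]
  },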
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dependencies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can install the warp_drive package using\n",
    "\n",
    "- the pip package manager, OR\n",
    "- by cloning the warp_drive package and installing the requirements.\n",
    "\n",
    "On Colab, we will do the latter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "IN_COLAB = 'google.colab' in sys.modules\n",
    "\n",
    "if IN_COLAB:\n",
    "    ! git clone https://github.com/salesforce/warp-drive.git \n",
    "    % cd warp-drive\n",
    "    ! pip install -e .\n",
    "else:\n",
    "    ! pip install rl_warp_drive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from example_envs.tag_continuous.tag_continuous import TagContinuous\n",
    "from warp_drive.env_wrapper import EnvWrapper\n",
    "from warp_drive.training.trainer import Trainer\n",
    "from warp_drive.training.utils.data_loader import create_and_push_data_placeholders\n",
    "\n",
    "pytorch_cuda_init_success = torch.cuda.FloatTensor(8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Environment, Training, and Model Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run_config = dict(\n",
    "    name = \"tag_continuous\",\n",
    "    \n",
    "    # Environment settings\n",
    "    env = dict(  \n",
    "        num_taggers = 5,\n",
    "        num_runners = 20,\n",
    "        episode_length = 100,\n",
    "        seed = 1234,\n",
    "        use_full_observation = False,\n",
    "        num_other_agents_observed = 10,\n",
    "        tagging_distance = 0.02,\n",
    "    ),\n",
    "\n",
    "    # Trainer settings\n",
    "    trainer = dict(\n",
    "        num_envs = 100,  # Number of environment replicas (numbre of GPU blocks used)\n",
    "        train_batch_size = 10000,  # total batch size used for training per iteration (across all the environments)\n",
    "        num_episodes = 5000,  # Total number of episodes to run the training for (can be arbitrarily high!)\n",
    "        algorithm = \"A2C\",  # trainer algorithm\n",
    "        vf_loss_coeff = 1,  # loss coefficient for the value function loss\n",
    "        entropy_coeff = 0.05,  # coefficient for the entropy component of the loss\n",
    "        clip_grad_norm = True,  # fla indicating whether to clip the gradient norm or not\n",
    "        max_grad_norm = 0.5,  # when clip_grad_norm is True, the clip level\n",
    "        normalize_advantage = False, # flag indicating whether to normalize advantage or not\n",
    "        normalize_return = False # flag indicating whether to normalize return or not\n",
    "    ), \n",
    "    \n",
    "    # Policy network settings\n",
    "    policy =  dict(\n",
    "        runner = dict(\n",
    "            to_train = True,\n",
    "            name = \"fully_connected\",\n",
    "            gamma = 0.98,  # discount rate gamms\n",
    "            lr = 0.005,  # learning rate\n",
    "            model = dict(     \n",
    "                fc_dims = [256, 256],  # dimension(s) of the fully connected layers as a list\n",
    "                model_ckpt_filepath = \"\"  # load model parameters from a saved checkpoint (if specified)\n",
    "            )\n",
    "        ),\n",
    "        tagger = dict(\n",
    "            to_train = True,\n",
    "            name = \"fully_connected\",\n",
    "            gamma = 0.98,\n",
    "            lr = 0.002,\n",
    "            model = dict(\n",
    "                fc_dims = [256, 256],\n",
    "                model_ckpt_filepath = \"\"\n",
    "            )\n",
    "        )\n",
    "    ),\n",
    "    \n",
    "    # Checkpoint saving setting\n",
    "    saving = dict(\n",
    "        print_metrics_freq = 10,  # How often (in iterations) to print the metrics\n",
    "        save_model_params_freq = 5000,  # How often (in iterations) to save the model parameters\n",
    "        basedir = \"/tmp\",  # base folder used for saving\n",
    "        tag = \"continuous-tag-experiment\",\n",
    "    )\n",
    ")"
   ]
  },
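  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how the trainer settings interact, the sketch below works through the arithmetic for the values above. It assumes that `train_batch_size` is split evenly across the `num_envs` replicas and that one training iteration consumes one full batch. The helper `summarize_batching` is hypothetical, and the Trainer's internal bookkeeping may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_batching(num_envs, train_batch_size, episode_length, num_episodes):\n",
    "    \"\"\"Worked arithmetic for the trainer settings (assumes an even split across envs).\"\"\"\n",
    "    # Environment steps each replica contributes to one batch\n",
    "    steps_per_env = train_batch_size // num_envs\n",
    "    # Number of full episodes covered by one batch\n",
    "    episodes_per_iter = train_batch_size // episode_length\n",
    "    # Iterations needed to cover num_episodes\n",
    "    num_iters = num_episodes // episodes_per_iter\n",
    "    return {\n",
    "        \"steps_per_env_per_iter\": steps_per_env,\n",
    "        \"episodes_per_iter\": episodes_per_iter,\n",
    "        \"num_iters\": num_iters,\n",
    "    }\n",
    "\n",
    "\n",
    "# Values copied from run_config above\n",
    "summary = summarize_batching(\n",
    "    num_envs=100, train_batch_size=10000, episode_length=100, num_episodes=5000\n",
    ")\n",
    "print(summary)"
   ]
  },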
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# End-to-End Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a wrapped environment object via the EnvWrapper\n",
    "# Ensure that use_cuda is set to True (in order to run on the GPU)\n",
    "env_wrapper = EnvWrapper(\n",
    "    TagContinuous(**run_config[\"env\"]),\n",
    "    num_envs=run_config[\"trainer\"][\"num_envs\"], \n",
    "    use_cuda=True\n",
    ")\n",
    "\n",
    "# Agents can share policy models: this dictionary maps policy model names to agent ids.\n",
    "policy_tag_to_agent_id_map = {\n",
    "    \"tagger\": list(env_wrapper.env.taggers),\n",
    "    \"runner\": list(env_wrapper.env.runners),\n",
    "}\n",
    "\n",
    "# Create the trainer object\n",
    "trainer = Trainer(\n",
    "    env_wrapper=env_wrapper,\n",
    "    config=run_config,\n",
    "    policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,\n",
    ")\n",
    "\n",
    "# Create and push data placeholders to the device\n",
    "create_and_push_data_placeholders(\n",
    "    env_wrapper, \n",
    "    policy_tag_to_agent_id_map, \n",
    "    trainer\n",
    ")\n",
    "\n",
    "# Perform training!\n",
    "trainer.train()\n",
    "\n",
    "# Shut off gracefully\n",
    "trainer.graceful_close()"
   ]
  },
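  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `policy_tag_to_agent_id_map` above assigns every agent to exactly one policy. The sketch below illustrates that structure with hypothetical agent ids (in the notebook, the real ids come from `env_wrapper.env.taggers` and `env_wrapper.env.runners`) and inverts the mapping to check that no agent is assigned twice. The helper `invert_policy_map` is our own illustration, not part of WarpDrive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def invert_policy_map(policy_to_agents):\n",
    "    \"\"\"Map each agent id to its policy name, rejecting duplicate assignments.\"\"\"\n",
    "    agent_to_policy = {}\n",
    "    for policy_name, agent_ids in policy_to_agents.items():\n",
    "        for agent_id in agent_ids:\n",
    "            if agent_id in agent_to_policy:\n",
    "                raise ValueError(f\"Agent {agent_id} is assigned to more than one policy\")\n",
    "            agent_to_policy[agent_id] = policy_name\n",
    "    return agent_to_policy\n",
    "\n",
    "\n",
    "# Hypothetical ids for illustration only: 5 taggers and 20 runners\n",
    "example_map = {\n",
    "    \"tagger\": list(range(5)),\n",
    "    \"runner\": list(range(5, 25)),\n",
    "}\n",
    "agent_to_policy = invert_policy_map(example_map)\n",
    "print(len(agent_to_policy), \"agents mapped\")"
   ]
  },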
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learn more and explore our tutorials"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To learn more about WarpDrive, take a look at these tutorials\n",
    "- [WarpDrive basics](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-1-warp_drive_basics.ipynb)\n",
    "- [WarpDrive sampler](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-2-warp_drive_sampler.ipynb)\n",
    "- [WarpDrive reset and log](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-3-warp_drive_reset_and_log.ipynb)\n",
    "- [Creating custom environments](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-4-create_custom_environments.ipynb)\n",
    "- [Training with WarpDrive](https://www.github.com/salesforce/warp-drive/blob/master/tutorials/tutorial-5-training_with_warp_drive.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
